Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better tests for utils #465

Merged
merged 4 commits into from
Feb 13, 2023
Merged

Conversation

acforvs
Copy link
Contributor

@acforvs acforvs commented Dec 8, 2022

Hi, hopefully this one closes #404

optax/_src/utils_test.py Outdated Show resolved Hide resolved
optax/_src/utils_test.py Outdated Show resolved Hide resolved
optax/_src/utils_test.py Outdated Show resolved Hide resolved
optax/_src/utils_test.py Outdated Show resolved Hide resolved
optax/_src/utils_test.py Outdated Show resolved Hide resolved
@hbq1
Copy link
Collaborator

hbq1 commented Feb 8, 2023

Hi @acforvs, many thanks for your PR and apologies for the belated review!
It looks good to me; I left only a few comments. Could you also add tests for the rest of helpers from utils.py, so we could close the original issue?

@acforvs
Copy link
Contributor Author

acforvs commented Feb 8, 2023

Hi @hbq1 , thank you for the review!

Sure thing, I'll add the tests for the remaining functions, but I have some questions about them:

  1. The docstring for set_diags states that the shape of new_diags should be NxD. However, new_diags is only used in one place where it is flattened.

As a result, examples like

a = jnp.ones(shape=(3, 2, 2))
diags = jnp.array([[[3, 3, 3, 3, 3, 3]]])
print(a.shape, diags.shape)  # prints (3, 2, 2) (1, 1, 6)
optax._src.utils.set_diags(a, diags)

or

a = jnp.ones(shape=(3, 2, 2))
diags = jnp.array([[3, 3, 3], [3, 3, 3]])
print(a.shape, diags.shape)  # prints (3, 2, 2) (2, 3)
optax._src.utils.set_diags(a, diags)

work just fine

Is this by design or should we raise an error in case of a shape mismatch?
Currently, the function is used internally in two places: 1, 2. The comments specify that the dimension of the passed variable is indeed NxD there, so I'd suggest adding another assert. What are your thoughts?

  1. cast_tree uses .astype which wouldn't work for simple datatypes, such as
tree = dict(a=2.5, b=dict(c=-2.5))
tree = jax.tree_util.tree_map(lambda x : x, tree)
optax._src.utils.cast_tree(tree, int)

Would it be okay to only test this function for dicts of jnp.arrays of differerent dtypes? Something like

tree = dict(a=jnp.array(2.5), b=dict(c=jnp.array(-2.5)))
tree = jax.tree_util.tree_map(lambda x : x, tree)
optax._src.utils.cast_tree(tree, int)

@hbq1
Copy link
Collaborator

hbq1 commented Feb 9, 2023

Great, thank you!

  1. It is used in some internal code, so the correct behaviour would be to raise an exception in case of shape mismatch; it'd be great if you add this to the implementation!
  2. Testing only with trees of jnp.arrays would suffice.

@acforvs
Copy link
Contributor Author

acforvs commented Feb 9, 2023

I've fixed the errors and added tests for the rest of the utils.py functions.
I also changed assert to a ValueError in set_diags. Please let me know how it looks

Copy link
Collaborator

@hbq1 hbq1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the changes and for the nice PR 👍

@hbq1 hbq1 closed this Feb 10, 2023
@hbq1 hbq1 reopened this Feb 10, 2023
@copybara-service copybara-service bot merged commit dd075fa into google-deepmind:master Feb 13, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Better tests for utils.
2 participants